﻿#include "WebsocketRespond.h"

#include <cstdio>
#include <iostream>
#include <map>
#include <sstream>
#include <string>

#include "sha1.h"
#include "base64.h"


#if defined(WIN32)
#define WIN32_LEAN_AND_MEAN
#include <windows.h>

#include <Winsock2.h>
#include <WS2tcpip.h>
#include <ws2bth.h>
#endif

#if defined(LINUX)
#include <sys/types.h>          /* See NOTES */
#include <sys/socket.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#endif


// GUID fixed by RFC 6455. The server appends it to the client's
// Sec-WebSocket-Key, SHA-1 hashes the result and base64-encodes the digest to
// produce the Sec-WebSocket-Accept header value.
#define MAGIC_KEY "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// Parsed HTTP request headers: header name -> header value.
// NOTE(review): this is a namespace-scope global with external linkage, yet the
// WebsocketRespond member functions in this file use `header_map_` like a data
// member. If WebsocketRespond.h also declares a header_map_ member, that member
// shadows this object and this global is unused; if it does not, every
// WebsocketRespond instance (and every thread) shares this single map.
// Confirm against the header before relying on either behavior.
std::map<std::string, std::string> header_map_;

// Creates a responder that starts with an empty header table.
// NOTE(review): if header_map_ resolves to the file-scope global rather than a
// class member, this also discards headers parsed by any other live instance.
WebsocketRespond::WebsocketRespond()
{
	header_map_.erase(header_map_.begin(), header_map_.end());
}
// Releases any headers still held when the responder goes away.
WebsocketRespond::~WebsocketRespond()
{
	if (!header_map_.empty()) {
		header_map_.clear();
	}
}


int WebsocketRespond::send_data(int fd, char *buff, int len)
{
	return  send(fd, buff, len, 0);
}

int WebsocketRespond::handshark(int fd, const char *buf, unsigned int len)
{
    fetch_http_info(buf, len);

	std::string msg;
	msg.append("HTTP/1.1 101 Switching Protocols\r\n");
	msg.append("Connection: Upgrade\r\n");
	msg.append("Sec-WebSocket-Accept: ");

	std::string server_key = header_map_["Sec-WebSocket-Key"];
	server_key += MAGIC_KEY;

	SHA1 sha;
	unsigned int message_digest[5];
	sha.Reset();
	sha << server_key.c_str();

	sha.Result(message_digest);
	for (int i = 0; i < 5; i++) {
		message_digest[i] = htonl(message_digest[i]);
	}
	server_key = base64_encode(reinterpret_cast<const unsigned char*>(message_digest), 20);
	server_key += "\r\n";

	msg += server_key;
	msg.append("Upgrade: websocket\r\n\r\n");

	send(fd, msg.c_str(), msg.size(), 0);

    return 0;
}

/**
 * Returns the decimal text form of n (e.g. -42 -> "-42").
 * Formatting is done by std::ostringstream, so it follows the stream's
 * (global) locale exactly as operator<< does.
 */
std::string Int_to_String(int n)
{
	std::ostringstream formatter;
	formatter << n;
	std::string text(formatter.str());
	return text;
}

/**
 * Parses the head of an HTTP request into header_map_.
 *
 * - The first (request) line must end in CR; it is only logged.
 * - Each following CRLF-terminated "Name: value" line is stored as
 *   header_map_[Name] = value. Optional spaces/tabs around the value are
 *   stripped (RFC 7230 3.2).
 * - Parsing stops at the blank "\r" line that ends the head, or at the end of
 *   the input.
 * - Empty lines, lines without a trailing CR, and lines without a "Name:"
 *   prefix are skipped.
 *
 * header_map_ is cleared first so headers from an earlier request cannot
 * leak into this one.
 *
 * @param buf raw request bytes (need not be NUL-terminated)
 * @param len number of bytes in buf
 * @return 0 on success; -1 if buf is null or empty, or the request line is not
 *         CR-terminated.
 */
int WebsocketRespond::fetch_http_info(const char *buf, unsigned int len)
{
	header_map_.clear();

	if (!buf || len == 0) {
		return -1;
	}

	const std::string raw(buf, len);
	std::istringstream stream(raw);
	std::string line;

	// Request line, e.g. "GET /chat HTTP/1.1\r". Check for an empty line before
	// touching its last character (size() - 1 would underflow).
	if (!std::getline(stream, line) || line.empty() || line[line.size() - 1] != '\r') {
		return -1;
	}
	line.erase(line.size() - 1);
	printf("request %s \n", line.c_str());

	while (std::getline(stream, line) && line != "\r") {
		if (line.empty() || line[line.size() - 1] != '\r') {
			continue; // not CRLF-terminated (e.g. a bare LF or truncated last line)
		}
		line.erase(line.size() - 1);
		printf("header %s \n", line.c_str());

		const std::string::size_type colon = line.find(':');
		if (colon == std::string::npos || colon == 0) {
			continue; // not a "Name: value" field
		}

		// field-value is surrounded by optional whitespace (OWS); trim it so
		// values such as Sec-WebSocket-Key hash correctly.
		std::string value;
		const std::string::size_type first = line.find_first_not_of(" \t", colon + 1);
		if (first != std::string::npos) {
			const std::string::size_type last = line.find_last_not_of(" \t");
			value = line.substr(first, last - first + 1);
		}
		header_map_[line.substr(0, colon)] = value;
	}

	return 0;
}
